import torch
import torch.nn as nn
import numpy as np


class ZhnNetHead(nn.Module):
    """Detection head with one confidence conv and one box-location conv.

    Both convolutions take a 96-channel feature map and preserve its spatial
    size (3x3 kernel, padding 1).

    :param train_backbone: when True, ``forward`` returns a single score per
        sample (sigmoid of the strongest confidence logit) instead of decoded
        boxes.
    :param device: stored on the instance for backward compatibility;
        ``decode_box`` follows the device of the tensor it is given rather
        than consulting this attribute.
    """

    def __init__(self, train_backbone=False, device=torch.device('cpu')):
        super().__init__()
        self.convconf = nn.Conv2d(96, 1, kernel_size=3, padding=1, bias=True)
        self.convloc = nn.Conv2d(96, 4, kernel_size=3, padding=1, bias=True)
        self.train_backbone = train_backbone
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the head on a feature map.

        :param x: features of shape N x 96 x H x W (the original author's
            notes use a 4 x 5 grid).
        :return: if ``train_backbone`` is set, a tensor of shape N holding the
            sigmoid of the maximum confidence logit per sample; otherwise the
            decoded predictions of shape N x H x W x 5.
        """
        conf = self.convconf(x)  # N x 1 x H x W
        if self.train_backbone:
            # Reuse the confidence map computed above; running the same
            # convolution a second time yields the same values.
            peak, _ = torch.max(conf.flatten(1), dim=1)  # N
            return torch.sigmoid(peak)  # N

        loc = self.convloc(x)  # N x 4 x H x W
        raw = torch.cat((loc, conf), dim=1)  # N x 5 x H x W
        return self.decode_box(raw)  # N x H x W x 5

    def decode_box(self, predict: torch.Tensor) -> torch.Tensor:
        """Turn raw network output into actual box predictions.

        :param predict: raw output of shape N x 5 x H x W (batch, prediction,
            height, width). The prediction channels are, in order:
            x offset, y offset, width, height, confidence.
        :return: a new tensor of shape N x H x W x 5 with, per grid cell,
            the box centre in grid units (sigmoid offset plus the cell's
            column/row index), the width and height, and the sigmoid
            confidence. Per the original author's note, multiply by 128 to
            obtain actual values. ``predict`` itself is left unmodified.
        """
        predict = predict.permute(0, 2, 3, 1)  # N, H, W, P
        grid_h, grid_w = predict.shape[1], predict.shape[2]

        # Build the cell indices on the same device and dtype as the input so
        # the head works wherever the tensor lives, and rely on broadcasting
        # instead of materialising full H x W grids.
        cols = torch.arange(grid_w, device=predict.device).to(predict.dtype)  # W
        rows = torch.arange(grid_h, device=predict.device).to(predict.dtype)
        rows = rows.unsqueeze(1)  # H x 1

        center_x = torch.sigmoid(predict[..., 0]) + cols  # x coordinate
        center_y = torch.sigmoid(predict[..., 1]) + rows  # y coordinate
        # NOTE(review): 1.62 and 2.29 look like a prior box width/height in
        # grid units -- confirm against the training configuration.
        width = 1.62 * torch.exp(predict[..., 2])
        height = 2.29 * torch.exp(predict[..., 3])
        confidence = torch.sigmoid(predict[..., 4])

        # Assemble a fresh tensor rather than writing into ``predict``, which
        # is a view of the caller's tensor.
        decoded = torch.stack(
            (center_x, center_y, width, height, confidence), dim=-1
        )
        extra = predict[..., 5:]
        if extra.shape[-1]:
            # Pass any additional channels through untouched.
            decoded = torch.cat((decoded, extra), dim=-1)
        return decoded
